import time
import random
import json
import numpy as np
import scipy.sparse as sp 
import imageio
import os
import argparse
import tqdm
import glob
import trimesh

from subprocess import call
from tree import PartTree, save_tree, load_tree
from partnet_renderer.render import NVDiffRasterizer

# Presumably the render resolution in pixels. NOTE(review): neither name is
# used by the code visible in this module (NVDiffRasterizer is constructed
# without them) — confirm they are still needed.
height, width = 400, 400

def load_obj(fn):
    """Parse vertex positions and triangle faces from a Wavefront OBJ file.

    Only ``v`` and ``f`` records are read; normals, texture coordinates,
    materials and groups are ignored. Only the first three coordinates of a
    ``v`` record are kept (trailing per-vertex colours are dropped). For each
    ``f`` record only the vertex index of the first three corners is kept
    (``v/vt/vn`` -> ``v``), so polygons with more corners are truncated to
    their first triangle.

    Args:
        fn: Path to the ``.obj`` file.

    Returns:
        ``(v, f)``: ``v`` is an (N, 3) float32 array of vertex positions and
        ``f`` an (M, 3) int32 array of vertex indices, left 1-based exactly
        as stored in the file.

    Raises:
        ValueError: if the file contains no vertex or no face records.
    """
    vertices = []
    faces = []
    # Stream the file under a context manager so the handle is closed even
    # if a malformed line makes NumPy's conversion raise.
    with open(fn, 'r') as fin:
        for raw_line in fin:
            line = raw_line.rstrip()
            if line.startswith('v '):
                vertices.append(np.float32(line.split()[1:4]))
            elif line.startswith('f '):
                faces.append(np.int32([item.split('/')[0] for item in line.split()[1:4]]))

    if not vertices or not faces:
        raise ValueError('%s: OBJ file contains no vertices or no faces' % fn)

    return np.vstack(vertices), np.vstack(faces)

def export_obj(out, v, f, color):
    """Write a triangle mesh as an OBJ file plus a one-material MTL file.

    The MTL file is written next to ``out`` with the same stem and a
    ``.mtl`` extension. The OBJ refers to it by basename because OBJ loaders
    resolve ``mtllib`` relative to the OBJ file's own directory.

    Args:
        out: Path of the ``.obj`` file to write.
        v: (N, >=3) array of vertex positions; the first three columns are written.
        f: (M, >=3) array of face vertex indices; the first three columns are
            written verbatim, so they must already be 1-based.
        color: Sequence whose first three entries are written as the diffuse
            (``Kd``) colour of material ``m1``.

    Returns:
        Path of the written ``.mtl`` file.
    """
    # splitext only touches the real extension. str.replace would also rewrite
    # ".obj" inside directory names, and it would return ``out`` itself when
    # there is no ".obj" suffix, which made the MTL write clobber the OBJ.
    mtl_out = os.path.splitext(out)[0] + '.mtl'

    obj_lines = ['mtllib %s\n' % os.path.basename(mtl_out), 'usemtl m1\n']
    obj_lines.extend('v %f %f %f\n' % (x, y, z) for x, y, z in v[:, :3])
    obj_lines.extend('f %d %d %d\n' % (a, b, c) for a, b, c in f[:, :3])
    with open(out, 'w') as fout:
        fout.writelines(obj_lines)

    with open(mtl_out, 'w') as fout:
        fout.write('newmtl m1\n')
        fout.write('Kd %f %f %f\n' % (color[0], color[1], color[2]))
        fout.write('Ka 0 0 0\n')

    return mtl_out

def render_mesh(rast):
    """Run the rasterizer once and return its outputs as NumPy arrays.

    ``rast()`` must return a mapping with tensor entries ``"comp_rgb"`` and
    ``"mask"``. Each one is squeezed, detached, moved to the CPU and converted
    with ``.numpy()``.

    Returns:
        ``(rgb, mask)``: presumably an (H, W, C) colour render and an (H, W)
        mask. The shapes depend on the rasterizer and should be confirmed there.
    """
    outputs = rast()

    def _as_array(tensor):
        return tensor.squeeze().detach().cpu().numpy()

    return _as_array(outputs["comp_rgb"]), _as_array(outputs["mask"])

def _shape_normalization(part_dir, part_ids):
    """Return ``(center, scale)`` computed over the vertices of every part mesh.

    ``center`` is the mean of all vertices of all parts. ``scale`` is the
    distance from ``center`` to the farthest vertex, multiplied by 1.5.
    """
    all_v = np.vstack([load_obj(os.path.join(part_dir, part_id + '.obj'))[0]
                       for part_id in part_ids])
    center = np.mean(all_v, axis=0)
    scale = np.sqrt(np.max(np.sum((all_v - center) ** 2, axis=1))) * 1.5
    return center, scale


def _cutout_on_black(base_rgb, mask):
    """Return a uint8 RGB image showing ``base_rgb`` where ``mask`` is set and black elsewhere.

    ``base_rgb`` is treated as float colours in [0, 1] (only its first three
    channels are used). ``mask`` is an (H, W) weight multiplied into every
    channel.
    """
    img = base_rgb[:, :, :3].copy()
    img *= mask[:, :, np.newaxis]
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)


def create_partnet_tree(partnet_dir, output) -> PartTree:
    """Build a PartTree for one PartNet shape and render a cutout image per node.

    The part meshes are read from ``<partnet_dir>/<output>/objs/*.obj`` and
    the hierarchy from the first entry of ``<partnet_dir>/<output>/result.json``.
    The whole shape is rendered once. For every hierarchy node, the node's
    leaf ``.obj`` names are passed to ``rast.set_mesh``, and the resulting
    mask is applied to the whole-shape render to give a cutout on a black
    background. Each cutout is saved as ``<id>.png``, with a
    ``"<id> <name> <text>"`` caption in ``<id>.txt``, under
    ``<partnet_dir>/<output>/parts_render_tree``, and is attached to the tree node.

    Args:
        partnet_dir: Directory holding one sub-directory per shape.
        output: Shape sub-directory name. It also names the tree's output
            directory ``vis/partnet/<output>/tree``.

    Returns:
        The populated PartTree.

    NOTE(review): assumes every node in result.json carries ``id``, ``text``
    and ``name`` keys (they are read unconditionally below).
    """
    cur_shape_dir = os.path.join(partnet_dir, output)
    cur_part_dir = os.path.join(cur_shape_dir, 'objs')
    leaf_part_ids = [item.split('.')[0] for item in os.listdir(cur_part_dir) if item.endswith('.obj')]

    cur_render_dir = os.path.join(cur_shape_dir, 'parts_render_tree')
    os.makedirs(cur_render_dir, exist_ok=True)

    center, scale = _shape_normalization(cur_part_dir, leaf_part_ids)

    # Whole-shape render; every node image is a masked cutout of it.
    rast = NVDiffRasterizer(cur_part_dir, center, scale, color=[0.93, 0.0, 0.0])
    root_render, _ = render_mesh(rast)

    with open(os.path.join(cur_shape_dir, 'result.json'), 'r') as fin:
        tree_hier = json.load(fin)[0]

    # Create part hierarchy tree
    output_dir = f'vis/partnet/{output}/tree'
    parttree = None

    def render(data):
        """Render node ``data`` and its subtree; return the leaf ``.obj`` names it covers."""
        nonlocal parttree
        if parttree is None:
            # The first node visited (the hierarchy root) becomes the tree root.
            parttree = PartTree(output_dir, data['id'])

        node = parttree.label_to_node[data['id']]
        node.gt_caption = data['text']

        masked_segments = []
        if 'objs' in data:  # leaf node
            masked_segments.extend(child + '.obj' for child in data['objs'])
        elif 'children' in data:  # parent node: covers all of its children's parts
            for child in data['children']:
                parttree.add_edge(child['id'], None, None, data['id'])
                masked_segments.extend(render(child))
        else:
            # Nothing to render. Return an empty list, not None, so the
            # parent's ``extend`` keeps working.
            return masked_segments

        rast.set_mesh(masked_segments)
        _, part_mask = render_mesh(rast)
        out = _cutout_on_black(root_render, part_mask)

        node_id = str(data['id'])
        out_filename = os.path.join(cur_render_dir, node_id + '.png')
        imageio.imwrite(out_filename, out)
        out_meta_fn = os.path.join(cur_render_dir, node_id + '.txt')
        with open(out_meta_fn, 'w', encoding='utf-8') as fout:
            fout.write(' '.join((node_id, data['name'], data['text'])).strip())

        parttree.set_node_image(data['id'], out, out_filename)
        return masked_segments

    render(tree_hier)

    return parttree

def main(args):
    """Build the part tree for ``args.output`` in ``args.partnet_dir``, render it, and save it.

    The rendered tree goes to ``<tree output_dir>/tree_with_images`` and the
    tree object is saved with ``save_tree`` to ``<tree output_dir>/tree.pkl``.
    """
    tree = create_partnet_tree(args.partnet_dir, args.output)

    image_tree_path = os.path.join(tree.output_dir, 'tree_with_images')
    tree.render_tree(image_tree_path)

    saved_tree_path = os.path.join(tree.output_dir, 'tree.pkl')
    save_tree(tree, saved_tree_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Render per-part cutout images for one PartNet shape and build its part tree.')

    # PartNet file. The default can be overridden with $PARTNET_DIR so the
    # script works outside the original author's machine without flags.
    parser.add_argument('--partnet_dir', type=str,
                        default=os.environ.get('PARTNET_DIR', '/home/codeysun/git/data/PartNet/data_v0/'),
                        help='Directory containing one sub-directory per PartNet shape '
                             '(default: $PARTNET_DIR if set).')
    parser.add_argument('--output', type=str, default='10558',
                        help='Shape sub-directory of --partnet_dir to process.')
    args = parser.parse_args()

    main(args)
